# # import library for usage of GPU
# import os
# os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
# os.environ["CUDA_VISIBLE_DEVICES"]="0"
# import library
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
from glob import glob
from PIL import Image
from scipy import misc
from sklearn.model_selection import train_test_split
# load in-focus and out-of-focus SEM images
def load_image_pairs(level, split, root='./data_files'):
    """Return the sorted out-of-focus input and in-focus label file lists.

    Inputs are globbed from ``{root}/out-of-focus_level_{level}/{split}/*.*``
    and labels from ``{root}/infocus_level_{level}/{split}/*.*``; both lists are
    sorted by path.  Downstream code pairs an input with a label purely by this
    sorted position, so the two directories must hold the same number of files
    (only the counts are checked here; matching file names are assumed).

    Args:
        level: defocus level (1, 2 or 3 in this script).
        split: sub-directory name, ``'train'`` or ``'test'``.
        root: data root directory.

    Returns:
        ``(images, labels)``: two sorted lists of file paths.

    Raises:
        ValueError: if the two directories contain different numbers of files,
            which would otherwise silently mis-pair inputs and labels.
    """
    images = sorted(glob('{}/out-of-focus_level_{}/{}/*.*'.format(root, level, split)))
    labels = sorted(glob('{}/infocus_level_{}/{}/*.*'.format(root, level, split)))
    if len(images) != len(labels):
        raise ValueError(
            'level {} {}: {} out-of-focus images but {} in-focus labels'.format(
                level, split, len(images), len(labels)))
    return images, labels


train_images_level_1, train_labels_level_1 = load_image_pairs(1, 'train')
test_images_level_1, test_labels_level_1 = load_image_pairs(1, 'test')
train_images_level_2, train_labels_level_2 = load_image_pairs(2, 'train')
test_images_level_2, test_labels_level_2 = load_image_pairs(2, 'test')
train_images_level_3, train_labels_level_3 = load_image_pairs(3, 'train')
test_images_level_3, test_labels_level_3 = load_image_pairs(3, 'test')
# Split each level's training pool 90/10 into train/validation. No random_state
# is passed, so the split differs between runs. The three levels are then merged
# into flat file-path arrays.
(train_images_level_1, valid_images_level_1,
 train_labels_level_1, valid_labels_level_1) = train_test_split(
    train_images_level_1, train_labels_level_1, test_size=0.10)
(train_images_level_2, valid_images_level_2,
 train_labels_level_2, valid_labels_level_2) = train_test_split(
    train_images_level_2, train_labels_level_2, test_size=0.10)
(train_images_level_3, valid_images_level_3,
 train_labels_level_3, valid_labels_level_3) = train_test_split(
    train_images_level_3, train_labels_level_3, test_size=0.10)
train_images = np.concatenate([train_images_level_1, train_images_level_2, train_images_level_3])
train_labels = np.concatenate([train_labels_level_1, train_labels_level_2, train_labels_level_3])
valid_images = np.concatenate([valid_images_level_1, valid_images_level_2, valid_images_level_3])
valid_labels = np.concatenate([valid_labels_level_1, valid_labels_level_2, valid_labels_level_3])
test_images = np.concatenate([test_images_level_1, test_images_level_2, test_images_level_3])
test_labels = np.concatenate([test_labels_level_1, test_labels_level_2, test_labels_level_3])
# data augmentation technique applied for MRN powered by DA
def cutblur(image, label, cut_size=None):
    """Paste a random square patch of ``label`` into a copy of ``image`` (CutBlur).

    Args:
        image: input (out-of-focus) patch. It is indexed on its first two axes
            and is NOT modified.
        label: target (in-focus) patch with the same height/width as ``image``.
        cut_size: side length of the pasted square. Defaults to
            ``int(crop_size / 4)`` using the module-level ``crop_size``.

    Returns:
        ``(augmented_image, label)``. ``augmented_image`` is a new array and
        ``label`` is returned unchanged.

    Raises:
        ValueError: if ``cut_size`` exceeds the image height or width.
    """
    if cut_size is None:
        cut_size = int(crop_size / 4)
    x_size = image.shape[0]
    y_size = image.shape[1]
    if cut_size > x_size or cut_size > y_size:
        raise ValueError('cut_size {} larger than image {}x{}'.format(cut_size, x_size, y_size))
    # randint's upper bound is exclusive: +1 lets the patch touch the last row/column
    start_x = np.random.randint(x_size - cut_size + 1)
    start_y = np.random.randint(y_size - cut_size + 1)
    augmented = np.array(image, copy=True)  # do not mutate the caller's array
    augmented[start_x:start_x + cut_size, start_y:start_y + cut_size] = \
        label[start_x:start_x + cut_size, start_y:start_y + cut_size]
    return augmented, label
# Batch-maker settings: side length of the random training crop (px), the
# down-sampling factor between consecutive scales, and the read cursor into the
# training file list.
crop_size, factor, current_batch = 256, 2, 0
def _three_scales(patch):
    """Return (coarsest, intermediate, finer) versions of *patch*, each /255 with a trailing channel axis.

    The intermediate scale is down-sampled by ``factor`` and the coarsest by
    ``factor**2`` (bicubic). Presumably *patch* is a 2-D grayscale array;
    TODO confirm for the data set.
    """
    finer = patch.copy()[:, :, np.newaxis]
    intermediate = misc.imresize(patch, 1.0 / factor, interp='bicubic')[:, :, np.newaxis]
    coarsest = misc.imresize(patch, 1.0 / (factor ** 2), interp='bicubic')[:, :, np.newaxis]
    return coarsest / 255.0, intermediate / 255.0, finer / 255.0


def train_batch_maker(batch_size):
    """Return the next training mini-batch as three-scale input/label pyramids.

    Reads the module-level ``train_images`` / ``train_labels`` path arrays
    sequentially from ``current_batch``. When fewer than ``batch_size`` entries
    remain, both arrays are reshuffled together (the module-level arrays are
    replaced) and reading restarts at the beginning. Each pair gets the same
    random ``crop_size`` x ``crop_size`` crop, then CutBlur, then a bicubic
    pyramid.

    Returns:
        Six arrays: (images_coarsest, images_intermediate, images_finer,
        labels_coarsest, labels_intermediate, labels_finer), each stacked over
        the batch with a trailing channel axis and values divided by 255.
    """
    global current_batch
    global train_images
    global train_labels
    if len(train_images) - current_batch < batch_size:
        # Epoch exhausted: shuffle the *global* arrays so the new order is actually
        # used. The original shuffled into locals that were immediately overwritten.
        order = np.random.permutation(len(train_images))
        train_images = train_images[order]
        train_labels = train_labels[order]
        current_batch = 0
    batch_train_images = train_images[current_batch:current_batch + batch_size]
    batch_train_labels = train_labels[current_batch:current_batch + batch_size]
    # Always advance, so the batch after a reshuffle is not served twice.
    current_batch += batch_size

    # Six output lists in return order: 3 image scales, then 3 label scales.
    buckets = [[] for _ in range(6)]
    for image_path, label_path in zip(batch_train_images, batch_train_labels):
        with Image.open(image_path) as im:
            temp_image = np.array(im)
        with Image.open(label_path) as im:
            temp_label = np.array(im)
        # +1 because randint's bound is exclusive (allows the last valid offset
        # and images exactly crop_size wide).
        start_x = np.random.randint(temp_image.shape[0] - crop_size + 1)
        start_y = np.random.randint(temp_image.shape[1] - crop_size + 1)
        temp_image = temp_image[start_x:start_x + crop_size, start_y:start_y + crop_size]
        temp_label = temp_label[start_x:start_x + crop_size, start_y:start_y + crop_size]
        temp_image, temp_label = cutblur(temp_image, temp_label)
        for bucket, scaled in zip(buckets, _three_scales(temp_image) + _three_scales(temp_label)):
            bucket.append(scaled)
    return tuple(np.array(bucket) for bucket in buckets)
# Show one sample of each of the six training arrays (inputs and ground truth at
# the coarsest, intermediate and finer scales) for MRN powered by DA.
train_images_coarsest, train_images_intermediate, train_images_finer, train_labels_coarsest, train_labels_intermediate, train_labels_finer = train_batch_maker(5)
for caption, sample_batch in (
        ('Input data for coarsest scale network', train_images_coarsest),
        ('Input data for intermediate scale network', train_images_intermediate),
        ('Input data for finer scale network', train_images_finer),
        ('Ground truth data for coarsest scale network', train_labels_coarsest),
        ('Ground truth data for intermediate scale network', train_labels_intermediate),
        ('Ground truth data for finer scale network', train_labels_finer)):
    print(caption)
    plt.figure(figsize=(20, 20))
    plt.imshow(sample_batch[0, :, :, 0], cmap='gray')
    plt.show()
# Preprocess valid data used for validation phase.
# Every validation pair is cropped at the fixed offset (250, 250), so the set is
# identical on every evaluation. CutBlur is intentionally NOT applied: it pastes
# ground-truth (in-focus) pixels into the input. That would leak the target into
# the validation loss that selects the saved checkpoint, and the test pipeline
# also feeds un-augmented inputs.
# NOTE(review): assumes every validation image is at least 250 + crop_size pixels
# in both dimensions. Smaller images give undersized crops that cannot be stacked
# into one array; confirm against the data set.
valid_images_coarsest = []
valid_images_intermediate = []
valid_images_finer = []
valid_labels_coarsest = []
valid_labels_intermediate = []
valid_labels_finer = []
start_x = 250
start_y = 250
for image, label in zip(valid_images, valid_labels):
    with Image.open(image) as im:
        temp_image = np.array(im)
    with Image.open(label) as im:
        temp_label = np.array(im)
    temp_image = temp_image[start_x:start_x + crop_size, start_y:start_y + crop_size]
    temp_label = temp_label[start_x:start_x + crop_size, start_y:start_y + crop_size]
    # Build the same three-scale pyramid for the input and for the label.
    for source, finer_list, intermediate_list, coarsest_list in (
            (temp_image, valid_images_finer, valid_images_intermediate, valid_images_coarsest),
            (temp_label, valid_labels_finer, valid_labels_intermediate, valid_labels_coarsest)):
        finer_list.append(source.copy()[:, :, np.newaxis] / 255.0)
        intermediate_list.append(
            misc.imresize(source, 1.0 / factor, interp='bicubic')[:, :, np.newaxis] / 255.0)
        coarsest_list.append(
            misc.imresize(source, 1.0 / (factor ** 2), interp='bicubic')[:, :, np.newaxis] / 255.0)
valid_images_coarsest = np.array(valid_images_coarsest)
valid_images_intermediate = np.array(valid_images_intermediate)
valid_images_finer = np.array(valid_images_finer)
valid_labels_coarsest = np.array(valid_labels_coarsest)
valid_labels_intermediate = np.array(valid_labels_intermediate)
valid_labels_finer = np.array(valid_labels_finer)
# Define placeholders: one input (x_*) and one target (y_*) per scale. Each is a
# float32 batch of single-channel images with free batch size, height and width.
def _image_batch_placeholder(name):
    """Create a float32 placeholder of shape [batch, height, width, 1] called *name*."""
    return tf.placeholder(tf.float32, [None, None, None, 1], name=name)


x_coarsest = _image_batch_placeholder('x_coarsest')
x_intermediate = _image_batch_placeholder('x_intermediate')
x_finer = _image_batch_placeholder('x_finer')
y_coarsest = _image_batch_placeholder('y_coarsest')
y_intermediate = _image_batch_placeholder('y_intermediate')
y_finer = _image_batch_placeholder('y_finer')
# define modules
def ResidualBlock(x, kernel_size, filters, strides = 1):
    """Conv -> PReLU -> conv (same padding, no bias), plus an identity shortcut.

    The shortcut addition requires the branch output to have the same shape as
    ``x``. In practice that means ``filters`` equal to the input channel count
    and ``strides=1`` (both convolutions use ``strides``).
    """
    def conv(tensor):
        return tf.layers.conv2d(tensor,
                                kernel_size=kernel_size,
                                filters=filters,
                                strides=strides,
                                padding='same',
                                use_bias=False)

    branch = conv(x)
    branch = tf.contrib.keras.layers.PReLU(shared_axes=[1, 2])(branch)
    branch = conv(branch)
    return branch + x
def Upsample2xBlock(x, kernel_size, filters, name, strides = 1):
    """Upsample 2x: conv, then depth_to_space(block 2), then ReLU, under variable scope *name*.

    ``depth_to_space`` with block size 2 divides the channel count by 4, so
    ``filters=4`` yields a single-channel output at twice the spatial size.
    """
    with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
        features = tf.layers.conv2d(x,
                                    kernel_size=kernel_size,
                                    filters=filters,
                                    strides=strides,
                                    padding='same')
        return tf.nn.relu(tf.depth_to_space(features, 2))
# define subnetworks
def _refocus_resnet(x, num_blocks, scope_name):
    """Shared body of the three per-scale refocusing ResNets.

    5x5 conv (64 ch) + PReLU, then ``num_blocks`` ResidualBlocks, a bias-free
    5x5 conv with a long skip back to the first features, and a final 5x5 conv
    to 1 channel named ``forward``. All variables live under ``scope_name``.
    Layer creation order is identical to the former per-scale copies, so
    variable names (and therefore saved checkpoints) are unchanged.

    Returns:
        Sigmoid of the single-channel output, i.e. values in (0, 1).
    """
    with tf.variable_scope(scope_name, reuse=tf.AUTO_REUSE):
        x = tf.layers.conv2d(x,
                             kernel_size=5,
                             filters=64,
                             strides=1,
                             padding='same')
        x = tf.contrib.keras.layers.PReLU(shared_axes=[1, 2])(x)
        skip = x
        for _ in range(num_blocks):
            x = ResidualBlock(x, kernel_size=5, filters=64, strides=1)
        x = tf.layers.conv2d(x,
                             kernel_size=5,
                             filters=64,
                             strides=1,
                             padding='same',
                             use_bias=False)
        x = x + skip
        x = tf.layers.conv2d(x,
                             kernel_size=5,
                             filters=1,
                             strides=1,
                             padding='same',
                             name='forward')
    return tf.nn.sigmoid(x)


def resnet_coarsest(x, num_blocks):
    """Refocusing ResNet for the coarsest scale (variable scope ``resnet_coarsest``)."""
    return _refocus_resnet(x, num_blocks, 'resnet_coarsest')


def resnet_intermediate(x, num_blocks):
    """Refocusing ResNet for the intermediate scale (variable scope ``resnet_intermediate``)."""
    return _refocus_resnet(x, num_blocks, 'resnet_intermediate')


def resnet_finer(x, num_blocks):
    """Refocusing ResNet for the finest scale (variable scope ``resnet_finer``)."""
    return _refocus_resnet(x, num_blocks, 'resnet_finer')
# Networks: a coarse-to-fine cascade. Each scale's refocused output is upsampled
# 2x and concatenated channel-wise with the next scale's out-of-focus input before
# entering that scale's ResNet.
## coarsest level network
refocus_coarsest = resnet_coarsest(x_coarsest, 16)
refocus_coarsest_upconv = Upsample2xBlock(
    refocus_coarsest, kernel_size=3, filters=4, name='upconv_for_intermediate')
refocus_coarsest_upconv_concat = tf.concat([refocus_coarsest_upconv, x_intermediate], axis=-1)
## intermediate level network
refocus_intermediate = resnet_intermediate(refocus_coarsest_upconv_concat, 16)
refocus_intermediate_upconv = Upsample2xBlock(
    refocus_intermediate, kernel_size=3, filters=4, name='upconv_for_finer')
refocus_intermediate_upconv_concat = tf.concat([refocus_intermediate_upconv, x_finer], axis=-1)
## finer level network
refocus_finer = resnet_finer(refocus_intermediate_upconv_concat, 16)
# Loss: mean absolute error (L1) between each scale's target and its prediction.
def _mean_abs_error(target, prediction):
    """Mean over all elements of |target - prediction|."""
    return tf.reduce_mean(tf.abs(target - prediction))


loss_coarsest = _mean_abs_error(y_coarsest, refocus_coarsest)
loss_intermediate = _mean_abs_error(y_intermediate, refocus_intermediate)
loss_finer = _mean_abs_error(y_finer, refocus_finer)
# Learning rate: LR decayed continuously (staircase=False) by 0.1 per 50000 steps.
LR = 0.00005
global_step = tf.contrib.framework.get_or_create_global_step()
learning_rate = tf.train.exponential_decay(LR, global_step, 50000, 0.1, staircase=False)
# None of the minimize() calls below receives global_step, so it is advanced by this op.
incr_global_step = tf.assign(global_step, global_step + 1)


# Variable lists: each optimizer updates only its own scale's ResNet, plus the
# up-sampling block that feeds that ResNet.
def _trainable_vars_in(*scope_names):
    """Trainable variables whose name contains any of *scope_names*, in collection order."""
    return [variable for variable in tf.get_collection('trainable_variables')
            if any(scope_name in variable.name for scope_name in scope_names)]


var_coarsest = _trainable_vars_in('resnet_coarsest')
var_intermediate = _trainable_vars_in('resnet_intermediate', 'upconv_for_intermediate')
var_finer = _trainable_vars_in('resnet_finer', 'upconv_for_finer')
# Optimizer: one Adam instance per scale, each minimizing that scale's L1 loss.
optm_coarsest = tf.train.AdamOptimizer(learning_rate).minimize(loss_coarsest, var_list=var_coarsest)
optm_intermediate = tf.train.AdamOptimizer(learning_rate).minimize(loss_intermediate, var_list=var_intermediate)
optm_finer = tf.train.AdamOptimizer(learning_rate).minimize(loss_finer, var_list=var_finer)
# Training parameters: number of iterations, logging interval, mini-batch size.
n_iter, n_prt, n_batch = 50000, 100, 5
# Open a session for training and initialise all variables (including Adam slots);
# the Saver covers every global variable.
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver()
# optimize a model during n_iter
import os


def _feed(images_coarsest, images_intermediate, images_finer,
          labels_coarsest, labels_intermediate, labels_finer):
    """Map the six multi-scale arrays onto the graph's x_*/y_* placeholders."""
    return {x_coarsest: images_coarsest,
            x_intermediate: images_intermediate,
            x_finer: images_finer,
            y_coarsest: labels_coarsest,
            y_intermediate: labels_intermediate,
            y_finer: labels_finer}


# tf.train.Saver.save refuses to write when the checkpoint's parent directory is missing.
os.makedirs('./model', exist_ok=True)
criteria = 10  # best validation loss so far (finest-scale L1 of sigmoid outputs, so at most 1)
loss_train_record = []
loss_valid_record = []
# The validation arrays never change, so build their feed dict once.
# NOTE(review): the whole validation set is evaluated in one batch every
# iteration; this may be slow or exhaust memory for a large set.
valid_feed = _feed(valid_images_coarsest, valid_images_intermediate, valid_images_finer,
                   valid_labels_coarsest, valid_labels_intermediate, valid_labels_finer)
for epoch in range(n_iter):  # one mini-batch per "epoch"
    train_batch = train_batch_maker(n_batch)
    (train_images_coarsest, train_images_intermediate, train_images_finer,
     train_labels_coarsest, train_labels_intermediate, train_labels_finer) = train_batch
    train_feed = _feed(*train_batch)
    sess.run([optm_coarsest, optm_intermediate, optm_finer], feed_dict=train_feed)
    sess.run(incr_global_step)
    # The finest-scale validation loss selects the checkpoint to keep.
    criteria_temp = sess.run(loss_finer, feed_dict=valid_feed)
    if criteria > criteria_temp:
        criteria = criteria_temp
        saver.save(sess, './model/MRN.ckpt')
    if epoch % n_prt == 0:
        loss_train = sess.run(loss_finer, feed_dict=train_feed)
        # Weights have not changed since criteria_temp was computed: reuse it.
        loss_valid = criteria_temp
        loss_train_record.append(loss_train)
        loss_valid_record.append(loss_valid)
        print('Epoch:', '%04d' % epoch, 'loss_train: {:.4}'.format(loss_train), 'loss_valid: {:.4}'.format(loss_valid))
        # Visual check: refocus the first sample of the current training batch.
        refocus_img = sess.run(refocus_finer,
                               feed_dict={ph: arr[:1] for ph, arr in train_feed.items()})
        plt.figure(figsize=(5, 5))
        plt.imshow(refocus_img[0, :, :, 0], cmap='gray')
        plt.axis('off')
        plt.show()
# Load the best model (lowest validation loss) saved during training.
save_file = './model/MRN.ckpt'
# Close the training session first; otherwise its resources (e.g. device memory)
# stay held alongside the new session.
sess.close()
saver = tf.train.Saver()
sess = tf.Session()
saver.restore(sess, save_file)
# Evaluate a random test image (levels 1, 2 and 3 are pooled) with MRN powered by DA.
# The pair is cropped so that both sides are multiples of factor**2: the cascade
# down-samples twice by `factor` and up-samples back, so the scales must line up.
# NOTE(review): assumes the test images are 2-D grayscale arrays; confirm.
n = np.random.randint(len(test_images))
with Image.open(test_images[n]) as im:
    test_image = np.array(im)
with Image.open(test_labels[n]) as im:
    test_label = np.array(im)
multiple = factor ** 2
# Integer remainder (the old code used a float `% 1` test); a zero remainder keeps the full image.
rows = test_image.shape[0] - test_image.shape[0] % multiple
cols = test_image.shape[1] - test_image.shape[1] % multiple
test_image = test_image[:rows, :cols]
test_label = test_label[:rows, :cols]
test_image_finer = test_image.copy()[np.newaxis, :, :, np.newaxis] / 255
test_image_intermediate = misc.imresize(test_image, 1.0 / factor, interp='bicubic')[np.newaxis, :, :, np.newaxis] / 255
test_image_coarsest = misc.imresize(test_image, 1.0 / (factor ** 2), interp='bicubic')[np.newaxis, :, :, np.newaxis] / 255
test_label_finer = test_label.copy()[np.newaxis, :, :, np.newaxis] / 255
test_label_intermediate = misc.imresize(test_label, 1.0 / factor, interp='bicubic')[np.newaxis, :, :, np.newaxis] / 255
test_label_coarsest = misc.imresize(test_label, 1.0 / (factor ** 2), interp='bicubic')[np.newaxis, :, :, np.newaxis] / 255
refocus_img = sess.run(refocus_finer, feed_dict={x_coarsest: test_image_coarsest,
                                                 x_intermediate: test_image_intermediate,
                                                 x_finer: test_image_finer,
                                                 y_coarsest: test_label_coarsest,
                                                 y_intermediate: test_label_intermediate,
                                                 y_finer: test_label_finer})
# Show, in this order: the refocused (output), out-of-focus (input) and in-focus (label) images.
for shown in (refocus_img[0, :, :, 0], test_image, test_label):
    plt.figure(figsize=(20, 20))
    plt.imshow(shown, cmap='gray')
    plt.axis('off')
    plt.show()
# Refocus a non-uniformly defocused image with MRN powered by DA.
# The (very wide) image is processed in vertical strips of `strip_width` columns.
# Each strip is cropped to multiples of factor**2 and refocused; the results and
# the absolute residuals |output - input| are stitched back together horizontally.
# NOTE(review): strips are refocused independently, so seams may be visible at
# strip boundaries; assumes a 2-D grayscale image.
file_name = './data_files/non_uniformly_defocused_image.png'
strip_width = 5000
# Decode the large image once, instead of once per strip.
with Image.open(file_name) as im:
    full_image = np.array(im)
non_uniform_refocus_list = []
non_uniform_abs_residual_intensity_map_list = []
multiple = factor ** 2
# Strip starts come from the real width (previously hard-coded as 58538), so no
# columns are skipped and no empty strip is produced.
for strip_start in range(0, full_image.shape[1], strip_width):
    # Slicing past the end simply truncates the last strip; no exception is possible.
    test_image = full_image[:, strip_start:strip_start + strip_width]
    rows = test_image.shape[0] - test_image.shape[0] % multiple
    cols = test_image.shape[1] - test_image.shape[1] % multiple
    test_image = test_image[:rows, :cols]
    test_image_finer = test_image.copy()[np.newaxis, :, :, np.newaxis] / 255
    test_image_intermediate = misc.imresize(test_image, 1.0 / factor, interp='bicubic')[np.newaxis, :, :, np.newaxis] / 255
    test_image_coarsest = misc.imresize(test_image, 1.0 / (factor ** 2), interp='bicubic')[np.newaxis, :, :, np.newaxis] / 255
    refocus_img = sess.run(refocus_finer, feed_dict={x_coarsest: test_image_coarsest,
                                                     x_intermediate: test_image_intermediate,
                                                     x_finer: test_image_finer})
    non_uniform_refocus_list.append(refocus_img[0, :, :, 0])
    non_uniform_abs_residual_intensity_map_list.append(np.abs(refocus_img[0, :, :, 0] - test_image / 255))
non_uniform_refocus = np.hstack(non_uniform_refocus_list)
non_uniform_abs_residual_intensity_map = np.hstack(non_uniform_abs_residual_intensity_map_list)
# show the refocused image & residual intensity map
plt.figure(figsize=(20, 20))
plt.imshow(non_uniform_refocus, 'gray')
plt.axis('off')
plt.show()
plt.figure(figsize=(20, 20))
plt.imshow(non_uniform_abs_residual_intensity_map)
plt.axis('off')
plt.show()